import numpy as np
import os
import torch
import torch.jit as jit
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import threading
from threading import Thread

# RNN variants
# from rnn_models.qrnn_2 import QRNNLayer
# from rnn_models.qrnn import QRNN, QRNNLayer2
# from rnn_models.qrnn_jit import JitQRNN
# from rnn_models.mgu import MGU, MGU2
# from rnn_models.mgu_jit import JitMGU, JitMGULayer, JitMGUCell
# from rnn_models.mgu2_jit import JitMGU2, JitMGU2Layer, JitMGU2Cell
# from rnn_models.gru_jit import JitGRU, JitGRULayer, JitGRUCell
# from rnn_models.lmu import LMUCell
from rnn_models.drnn import DRNN

from jit_nn_modules import *
from nn_modules import *

from distributions import Categorical
from utils import init

# Names of recurrent cell variants, grouped into two families.
GRUS = ['gru', 'gru_jit', 'mgu', 'mgu_jit', 'mgu2_jit', 'qrnn', 'qrnn_jit', 'qrnn_2', 'lmu', 'drnn']
LSTMS = ['lstm']


# Directory that holds this module; checkpoint paths are resolved against it.
MODEL_DIR = os.path.dirname(__file__)

# Maps an encoder name to the path of its checkpoint file under MODEL_DIR.
MODEL_PATH = {
    encoder: os.path.join(MODEL_DIR, f'resnet_trained_models/{encoder}_checkpoint_100.tar')
    for encoder in ('resnet18', 'resnet50')
}


def load_simclr_encoder(proj_dim=64, encoder_name='resnet18'):
    """Build a SimCLR model and restore its weights from the bundled checkpoint.

    :param proj_dim: value passed to ``SimCLR`` as ``projection_dim``; it is
        also returned unchanged as the second element of the result.
    :param encoder_name: key into ``MODEL_PATH`` selecting the checkpoint file;
        also passed to ``SimCLR`` as ``encoder_name``.
    :return: tuple ``(model, proj_dim)`` with the model switched to eval mode.
    """
    encoder = SimCLR(encoder_name=encoder_name, projection_dim=proj_dim)
    # Weights are mapped to the CPU regardless of where they were saved.
    checkpoint = torch.load(MODEL_PATH[encoder_name], map_location='cpu')
    encoder.load_state_dict(checkpoint)
    encoder.eval()
    return encoder, proj_dim


class ThreadWithReturnValue(Thread):
    """Thread whose ``join`` hands back the value returned by ``target``.

    :param group: forwarded to ``Thread.__init__``.
    :param target: callable run in the thread; its return value is stored.
    :param name: forwarded to ``Thread.__init__``.
    :param args: positional arguments for ``target``.
    :param kwargs: keyword arguments for ``target``; ``None`` means no
        keyword arguments (avoids sharing one mutable default dict).
    :param Verbose: accepted for backward compatibility and ignored.
    """

    def __init__(self, group=None, target=None, name=None,
                 args=(), kwargs=None, Verbose=None):
        super().__init__(group=group, target=target, name=name,
                         args=args, kwargs=kwargs)
        # Stays None until run() finishes (or if there is no target).
        self._return = None

    def run(self):
        """Invoke the target, if any, and remember what it returned."""
        if self._target is not None:
            self._return = self._target(*self._args, **self._kwargs)

    def join(self, timeout=None):
        """Wait for the thread and return the target's result.

        :param timeout: optional timeout in seconds, as for ``Thread.join``.
        :return: the stored return value; ``None`` if the target has not
            finished (e.g. the timeout expired) or there was no target.
        """
        super().join(timeout)
        return self._return

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        # Keep dim 0 as-is and let view() infer the merged trailing size.
        leading = x.shape[0]
        return x.view(leading, -1)


class FCNetwork(nn.Module):
    """Fully connected stack: Linear layers separated by ReLU activations."""

    def __init__(self, dims, out_layer=None):
        """
        Build the layer stack; no activation follows the last Linear layer.
        :param dims: sequence of layer widths, e.g. (100, 100, ..., 5); the
            first entry is the input size and needs at least one successor.
        :param out_layer: optional module appended after the last Linear layer
        """
        super().__init__()

        # The first Linear has no preceding activation.
        stack = [nn.Linear(dims[0], dims[1])]
        # Every further Linear is preceded by a ReLU.
        for fan_in, fan_out in zip(dims[1:-1], dims[2:]):
            stack.append(nn.ReLU())
            stack.append(nn.Linear(fan_in, fan_out))

        if out_layer:
            stack.append(out_layer)

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.layers(x)

    def hard_update(self, source):
        """Overwrite this network's parameters with those of ``source``."""
        for own, other in zip(self.parameters(), source.parameters()):
            own.data.copy_(other.data)

    def soft_update(self, source, t):
        """Blend parameters in place: ``(1 - t) * own + t * source``."""
        for own, other in zip(self.parameters(), source.parameters()):
            blended = (1 - t) * own.data + t * other.data
            own.data.copy_(blended)

class Policy(nn.Module):
    """Actor-critic policy: a base network plus a categorical action head.

    The base is ``MGMLPBase`` when ``env_name`` contains ``'MarlGrid'`` and
    ``MLPBase`` otherwise. The base returns ``(value, actor_features,
    rnn_hxs)``; ``actor_features`` are fed to a ``Categorical`` head sized by
    ``action_space.n``.

    :param obs_space: observation space; ``.shape`` is read, and for MarlGrid
        it is also indexed (``obs_space[0]``, ``obs_space[1]``).
    :param action_space: must expose ``.n`` (number of discrete actions).
    :param hidden_size: forwarded to the base.
    :param use_spectral_norm: forwarded to the base.
    :param use_conv_state_encoder: forwarded to the base; for MarlGrid it also
        switches ``obs_shape`` to ``obs_space[0].shape``.
    :param env_name: environment name; ``None`` is treated as non-MarlGrid.
    :param use_norm: forwarded to ``MGMLPBase`` only.
    :param base: unused; kept for interface compatibility.
    :param base_kwargs: extra keyword arguments for the base constructor.
    """

    def __init__(self, obs_space, action_space, hidden_size = 64, use_spectral_norm = False, use_conv_state_encoder = False, env_name = None, use_norm = False, base=None, base_kwargs=None):
        super(Policy, self).__init__()

        # env_name defaults to None, so guard the substring test instead of
        # letting `'MarlGrid' in None` raise a TypeError.
        is_marlgrid = env_name is not None and 'MarlGrid' in env_name

        obs_shape = obs_space.shape
        if use_conv_state_encoder and is_marlgrid:
            obs_shape = obs_space[0].shape

        if base_kwargs is None:
            base_kwargs = {}

        if is_marlgrid:
            self.base = MGMLPBase(obs_shape, obs_space[1].shape[0], hidden_size = hidden_size, use_spectral_norm = use_spectral_norm, use_conv_state_encoder = use_conv_state_encoder, use_norm = use_norm, **base_kwargs)
        else:
            self.base = MLPBase(obs_shape[0], hidden_size = hidden_size, use_spectral_norm = use_spectral_norm, use_conv_state_encoder = use_conv_state_encoder, **base_kwargs)

        num_outputs = action_space.n
        self.dist = Categorical(self.base.output_size, num_outputs)

    @property
    def is_recurrent(self):
        """Whether the base network is recurrent."""
        return self.base.is_recurrent

    @property
    def recurrent_hidden_state_size(self):
        """Size of rnn_hx."""
        return self.base.recurrent_hidden_state_size

    def forward(self, inputs, rnn_hxs, masks):
        """Not supported; use ``act``, ``get_value`` or ``evaluate_actions``."""
        raise NotImplementedError

    def act(self, inputs, rnn_hxs, masks, deterministic=False):
        """Choose an action for ``inputs``.

        :param deterministic: take the distribution's mode instead of sampling.
        :return: ``(value, action, action_log_probs, rnn_hxs)``.
        """
        value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks)
        dist = self.dist(actor_features)

        action = dist.mode() if deterministic else dist.sample()
        action_log_probs = dist.log_probs(action)

        return value, action, action_log_probs, rnn_hxs

    def get_value(self, inputs, rnn_hxs, masks):
        """Return only the base network's value output."""
        value, _, _ = self.base(inputs, rnn_hxs, masks)
        return value

    def evaluate_actions(self, inputs, rnn_hxs, masks, action):
        """Score given actions under the current policy.

        :return: ``(value, action_log_probs, dist_entropy, rnn_hxs, probs)``
            where ``dist_entropy`` is the mean entropy and ``probs`` is the
            action distribution's ``probs`` attribute.
        """
        value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks)
        dist = self.dist(actor_features)

        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean()

        return value, action_log_probs, dist_entropy, rnn_hxs, dist.probs

# Module-level lock. No live code visible in this part of the file acquires it;
# presumably meant to serialize per-agent writes in a threaded act path.
# NOTE(review): confirm no other module imports it before removing.
para_act_lock = threading.Lock()
class CommPolicy(nn.Module):
    """Actor-critic policy whose base network also consumes and emits messages.

    The constructor selects one base network from the flags below, optionally
    builds "aligner" heads applied to outgoing messages during
    ``evaluate_actions``, and adds a ``Categorical`` action head sized by
    ``action_space.n``.

    Base selection, in priority order:
      1. ``use_comm_gate``           -> ``CommGatedMLPBase``
      2. ``use_memory``              -> ``CommSDNDMLPBase`` (``mem_type ==
         "SDND"``) or ``CommLinearMemMLPBase`` (``mem_type == "LinMem"``)
      3. ``use_comm_sep_rnn``        -> ``CommMLPBase_CSRNN``
      4. ``'MarlGrid' in env_name``  -> ``CommMGBase``
      5. otherwise                   -> ``CommAltMLPBase``

    :param obs_space: observation space; ``.shape`` is read, and for MarlGrid
        it is indexed (``obs_space[0]``, ``obs_space[1]``).
    :param action_space: must expose ``.n``.
    :param message_space: indexable; element ``[2]`` is used here as the
        message size fed to projector/aligner heads, the whole object is
        forwarded to the base.
    :param env_name: environment name; ``None`` is treated as non-MarlGrid.
    :param base: unused; kept for interface compatibility.
    :param base_kwargs: extra keyword arguments for the base constructor.
    :raises ValueError: ``use_memory`` is set with an unknown ``mem_type``.
    :raises NotImplementedError: ``use_aligner`` is set with an
        ``aligner_type`` that has no 'mi'/'mm' substring and is not one of
        the plain obs/act variants.
    """

    def __init__(
        self,
        obs_space,
        action_space,
        message_space,
        use_comm_gate = False,
        use_memory = False,
        use_aligner = False,
        aligner_type = 'obs',
        mem_size = 0,
        mem_type = "SDND",
        mem_key_size = 6,
        hidden_size = 64,
        use_comm_sep_rnn = False,
        use_spectral_norm = False,
        use_dial = False,
        use_conv_state_encoder = False,
        num_agents = 4,
        use_projector= False,
        env_name = None,
        use_norm = False,
        base=None,
        base_kwargs=None,
    ):
        super(CommPolicy, self).__init__()

        # env_name defaults to None, so guard the substring test instead of
        # letting `'MarlGrid' in None` raise a TypeError.
        is_marlgrid = env_name is not None and 'MarlGrid' in env_name

        if not use_conv_state_encoder:
            obs_shape = obs_space.shape
        elif is_marlgrid:
            obs_shape = obs_space[0].shape
        else:
            # Placeholder width for the conv-encoder path; presumably the
            # SimCLR projection size -- TODO confirm.
            obs_shape = [64]

        if base_kwargs is None:
            base_kwargs = {}

        if use_comm_gate:
            self.base = CommGatedMLPBase(obs_shape[0], message_space, hidden_size= hidden_size, use_spectral_norm = use_spectral_norm, **base_kwargs)
        elif use_memory:
            # External memory modules are flagged as no longer supported upstream.
            if mem_type == "SDND":
                self.base = CommSDNDMLPBase(obs_shape[0], message_space, key_size = mem_key_size, hidden_size = hidden_size, use_spectral_norm = use_spectral_norm, **base_kwargs)
            elif mem_type == "LinMem":
                self.base = CommLinearMemMLPBase(obs_shape[0], message_space, mem_size, hidden_size = hidden_size, use_spectral_norm = use_spectral_norm, **base_kwargs)
            else:
                raise ValueError("No support for this type of memory module")
        elif use_comm_sep_rnn:
            # Separate RNN for communication.
            self.base = CommMLPBase_CSRNN(obs_shape[0], message_space, hidden_size = hidden_size, use_spectral_norm = use_spectral_norm, **base_kwargs)
        elif is_marlgrid:
            self.base = CommMGBase(obs_shape, obs_space[1].shape[0], message_space, hidden_size = hidden_size, use_spectral_norm = use_spectral_norm, use_dial = use_dial, use_conv_state_encoder = use_conv_state_encoder, aligner_type = aligner_type, use_norm = use_norm, **base_kwargs)
        else:
            self.base = CommAltMLPBase(None if use_conv_state_encoder else obs_shape[0], message_space, hidden_size = hidden_size, use_spectral_norm = use_spectral_norm, use_dial = use_dial, use_conv_state_encoder = use_conv_state_encoder, aligner_type = aligner_type, **base_kwargs)

        self.use_comm_sep_rnn = use_comm_sep_rnn
        self.use_comm_gate = use_comm_gate
        self.use_memory = use_memory
        self.use_aligner = use_aligner
        self.use_projector = use_projector
        self.aligner_type = aligner_type

        if self.use_aligner:
            msg_dim = message_space[2]
            if self.use_projector:
                # Only used during learning
                self.projector = CommHeadProjector(msg_dim)
            if 'mi' in self.aligner_type:
                # Other 'mi*' types (e.g. 'mi_split') build no aligner here.
                if self.aligner_type == 'mi_obs':
                    # Reconstructs the encoded observation.
                    self.aligner = CommHeadAligner(hidden_size, hidden_size, msg_dim)
                elif self.aligner_type in ('mi_act', 'mi_kl_act'):
                    self.aligner = CommActHeadAligner(msg_dim, hidden_size = hidden_size)
                    self.aligner_act_dist = Categorical(self.base.output_size, action_space.n)
                elif self.aligner_type in ('mi_obs_act', 'mi_obs_kl_act'):
                    if is_marlgrid:
                        self.o_aligner = CommHeadAligner(hidden_size + self.base.num_direct_features, hidden_size, msg_dim)
                    else:
                        self.o_aligner = CommHeadAligner(obs_shape[0], hidden_size, msg_dim)
                    self.a_aligner = CommActHeadAligner(msg_dim, hidden_size = hidden_size)
                    self.aligner_act_dist = Categorical(self.base.output_size, action_space.n)
            elif 'mm' in self.aligner_type:
                self.aligner = CommMMHeadAligner(msg_dim, num_agents, hidden_size)
            else:
                if self.aligner_type == 'obs':
                    # Reconstructs the encoded observation.
                    self.aligner = CommHeadAligner(hidden_size, hidden_size, msg_dim)
                elif self.aligner_type in ('act', 'kl_act'):
                    # NOTE(review): the MarlGrid case builds a CommHeadAligner
                    # (observation-style head) for an action aligner; kept as
                    # in the original -- confirm this is intended.
                    if is_marlgrid:
                        self.aligner = CommHeadAligner(hidden_size + self.base.num_direct_features, hidden_size, msg_dim)
                    else:
                        self.aligner = CommActHeadAligner(msg_dim, hidden_size = hidden_size)
                    self.aligner_act_dist = Categorical(self.base.output_size, action_space.n)
                elif self.aligner_type in ('obs_act', 'obs_kl_act'):
                    self.o_aligner = CommHeadAligner(hidden_size, hidden_size, msg_dim)
                    self.a_aligner = CommActHeadAligner(msg_dim, hidden_size = hidden_size)
                    self.aligner_act_dist = Categorical(self.base.output_size, action_space.n)
                else:
                    raise NotImplementedError

        num_outputs = action_space.n
        self.dist = Categorical(self.base.output_size, num_outputs)

    @property
    def is_recurrent(self):
        """Whether the base network is recurrent."""
        return self.base.is_recurrent

    @property
    def recurrent_hidden_state_size(self):
        """Size of rnn_hx."""
        return self.base.recurrent_hidden_state_size

    def forward(self, inputs, rnn_hxs, masks):
        """Not supported; use ``act``, ``get_value`` or ``evaluate_actions``."""
        raise NotImplementedError

    def act(self, inputs, msg_inputs, rnn_hxs, masks, comm_partial_with_grad = False, deterministic=False, crnn_hxs = None):
        """Choose an action and produce an outgoing message.

        :param comm_partial_with_grad: forwarded to the base; when true the
            action head below runs with gradient tracking disabled.
        :param deterministic: take the distribution's mode instead of sampling.
        :param crnn_hxs: hidden state of the separate communication RNN; only
            passed to the base when ``use_comm_sep_rnn`` is set.
        :return: ``(value, action, action_log_probs, message, rnn_hxs,
            crnn_hxs)``; ``crnn_hxs`` is ``None`` unless ``use_comm_sep_rnn``.
        """
        if self.use_comm_sep_rnn:
            value, actor_features, messages, rnn_hxs, crnn_hxs = self.base(inputs, msg_inputs, rnn_hxs, crnn_hxs, masks, comm_partial_with_grad)
        else:
            value, actor_features, messages, rnn_hxs = self.base(inputs, msg_inputs, rnn_hxs, masks, comm_partial_with_grad)

        with torch.set_grad_enabled(not comm_partial_with_grad):
            dist = self.dist(actor_features)
            action = dist.mode() if deterministic else dist.sample()
            action_log_probs = dist.log_probs(action)

        # Only return the first item for messages, the second item is for alignment which doesn't backpropagate across agents
        return value, action, action_log_probs, messages[0], rnn_hxs, crnn_hxs if self.use_comm_sep_rnn else None

    def get_value(self, inputs, msg_inputs, rnn_hxs, masks, crnn_hxs = None):
        """Return only the base network's value output."""
        if self.use_comm_sep_rnn:
            value, _, _, _, _ = self.base(inputs, msg_inputs, rnn_hxs, crnn_hxs, masks)
        else:
            value, _, _, _ = self.base(inputs, msg_inputs, rnn_hxs, masks)
        return value

    def evaluate_actions(self, inputs, msg_inputs, rnn_hxs, masks, action, crnn_hxs = None):
        """Score given actions and, if enabled, run the aligner heads.

        :return: ``(value, action_log_probs, dist_entropy, messages, rnn_hxs,
            aligned_output, probs, crnn_hxs)``. ``messages`` is the full pair
            returned by the base. ``aligned_output`` is ``None`` when no
            aligner applies, a single tensor for obs-only or act-only types,
            or an ``(obs, act_probs)`` tuple for the combined types.
            ``crnn_hxs`` is ``None`` unless ``use_comm_sep_rnn``.
        """
        if self.use_comm_sep_rnn:
            value, actor_features, messages, rnn_hxs, crnn_hxs = self.base(inputs, msg_inputs, rnn_hxs, crnn_hxs, masks)
        else:
            value, actor_features, messages, rnn_hxs = self.base(inputs, msg_inputs, rnn_hxs, masks)
        dist = self.dist(actor_features)

        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean()

        aligned_output = None
        if self.use_aligner:
            # Message copy that doesn't backpropagate to other agents
            g_messages = self.projector(messages[1]) if self.use_projector else messages[1]
            if self.aligner_type in ('obs', 'mi_obs'):
                aligned_output = self.aligner(g_messages)
            elif self.aligner_type in ('act', 'kl_act', 'mi_act', 'mi_kl_act'):
                aligned_output = self.aligner_act_dist(self.aligner(g_messages)).probs
            elif self.aligner_type in ('obs_act', 'obs_kl_act', 'mi_obs_act', 'mi_obs_kl_act'):
                aligned_output = (self.o_aligner(g_messages), self.aligner_act_dist(self.a_aligner(g_messages)).probs)

        out_crnn_hxs = crnn_hxs if self.use_comm_sep_rnn else None
        return value, action_log_probs, dist_entropy, messages, rnn_hxs, aligned_output, dist.probs, out_crnn_hxs


class NNBase(nn.Module):
    """Common base for policy networks with an optional GRU core.

    When ``recurrent`` is true two independent ``nn.GRU`` modules are created:
    ``rnn`` (used by ``_forward_gru``) and ``comm_rnn`` (used by
    ``_forward_comm_gru``).

    :param recurrent: whether to create the GRUs.
    :param recurrent_input_size: GRU input feature size.
    :param hidden_size: GRU hidden size; also reported as ``output_size``.
    :param recurrent_type: only ``'gru'`` is supported.
    :raises NotImplementedError: ``recurrent`` is true and ``recurrent_type``
        is not ``'gru'``.
    """

    def __init__(self, recurrent, recurrent_input_size, hidden_size, recurrent_type = 'gru'):
        super(NNBase, self).__init__()
        self._hidden_size = hidden_size
        self._recurrent = recurrent
        self._recurrent_type = recurrent_type

        if recurrent:
            if recurrent_type != 'gru':
                raise NotImplementedError("This is not supported")
            self.rnn = nn.GRU(recurrent_input_size, hidden_size)
            self.comm_rnn = nn.GRU(recurrent_input_size, hidden_size)
            # NOTE(review): only `rnn` gets the orthogonal/zero init below;
            # `comm_rnn` keeps PyTorch's default init. Kept as in the
            # original -- confirm whether that asymmetry is intended.
            for name, param in self.rnn.named_parameters():
                if "bias" in name:
                    nn.init.constant_(param, 0)
                elif "weight" in name:
                    nn.init.orthogonal_(param)

    @property
    def is_recurrent(self):
        """Whether this network was built with a recurrent core."""
        return self._recurrent

    @property
    def recurrent_hidden_state_size(self):
        """Hidden state size; 1 when the network is not recurrent."""
        if self._recurrent:
            return self._hidden_size
        return 1

    @property
    def output_size(self):
        """Feature size reported to downstream heads (the hidden size)."""
        return self._hidden_size

    def _forward_gru(self, x, hxs, masks):
        """Run ``self.rnn`` over ``x``; see ``_forward_recurrent``."""
        return self._forward_recurrent(self.rnn, x, hxs, masks)

    def _forward_comm_gru(self, x, hxs, masks):
        """Run ``self.comm_rnn`` over ``x``; see ``_forward_recurrent``."""
        return self._forward_recurrent(self.comm_rnn, x, hxs, masks)

    def _forward_recurrent(self, rnn, x, hxs, masks):
        """Apply ``rnn`` to one step or to a flattened sequence.

        The hidden state is multiplied by ``masks`` before being fed to the
        RNN, so a zero mask entry resets that row's hidden state.

        :param rnn: recurrent module called as ``rnn(input, hidden)``.
        :param x: either ``(N, F)`` for a single step, or a ``(T, N, F)``
            sequence flattened to ``(T * N, F)``.
        :param hxs: hidden state with ``N`` rows.
        :param masks: one entry per row of ``x``.
        :return: ``(outputs, hxs)`` with ``outputs`` flattened like ``x``.
        """
        if x.size(0) == hxs.size(0):
            # Single step: add and strip the leading time dimension.
            out, hxs = rnn(x.unsqueeze(0), (hxs * masks).unsqueeze(0))
            return out.squeeze(0), hxs.squeeze(0)

        # x is a (T, N, -1) tensor that has been flattened to (T * N, -1)
        N = hxs.size(0)
        T = x.size(0) // N

        x = x.view(T, N, x.size(1))
        masks = masks.view(T, N)

        # Steps (after t=0) where any row has a zero mask. Consecutive steps
        # without a reset are run through the RNN in one call.
        reset_steps = (masks[1:] == 0.0).any(dim=-1).nonzero().flatten().cpu()
        # +1 corrects for the masks[1:] offset; t=0 and t=T bracket the list.
        boundaries = [0] + (reset_steps + 1).tolist() + [T]

        hxs = hxs.unsqueeze(0)
        outputs = []
        for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
            # Apply the mask of the segment's first step to the carried state.
            rnn_scores, hxs = rnn(x[start_idx:end_idx], hxs * masks[start_idx].view(1, -1, 1))
            outputs.append(rnn_scores)

        # Back to the flattened (T * N, -1) layout.
        out = torch.cat(outputs, dim=0).view(T * N, -1)
        return out, hxs.squeeze(0)



class MLPBase(NNBase):
    """State encoder, optional GRU, and actor/critic heads.

    :param num_inputs: input size for ``StateEncoder`` (ignored by the conv
        encoder path).
    :param recurrent: whether to run the encoded state through the GRU.
    :param hidden_size: width of the encoder, GRU and heads.
    :param recurrent_type: forwarded to ``NNBase``.
    :param use_spectral_norm: forwarded to ``RLHeadNetworks``.
    :param use_conv_state_encoder: use ``ConvStateEncoder`` instead of
        ``StateEncoder``.
    """

    def __init__(self, num_inputs, recurrent=False, hidden_size=32, recurrent_type = 'gru', use_spectral_norm = False, use_conv_state_encoder = False):
        super(MLPBase, self).__init__(recurrent, hidden_size, hidden_size, recurrent_type)

        # Both encoder variants must live under `state_encoder_networks`,
        # because that is the attribute forward() calls.
        if use_conv_state_encoder:
            self.state_encoder_networks = ConvStateEncoder(hidden_size)
        else:
            self.state_encoder_networks = StateEncoder(num_inputs, hidden_size)

        if recurrent:
            num_inputs = hidden_size

        # NOTE(review): when not recurrent the heads are sized by the raw
        # `num_inputs` although forward() feeds them the encoder output;
        # presumably only correct if num_inputs == hidden_size -- confirm.
        self.comm_rl_head_networks = RLHeadNetworks(num_inputs, hidden_size, use_spectral_norm)
        self.train()

    def forward(self, inputs, rnn_hxs, masks):
        """Encode ``inputs`` and compute the critic output and actor features.

        :return: ``(output_critic, hidden_actor, rnn_hxs)``; ``rnn_hxs`` is
            returned unchanged when the network is not recurrent.
        """
        x = self.state_encoder_networks(inputs)
        if self.is_recurrent:
            x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

        output_critic, hidden_actor = self.comm_rl_head_networks(x)
        return output_critic, hidden_actor, rnn_hxs

class MGMLPBase(NNBase):
    """Base network for observations made of an image plus direct features.

    The image part goes through ``MarlGridImgModule`` (optionally followed by
    ``LayerNorm``), is concatenated with the direct features, encoded by
    ``StateEncoder``, optionally passed through the GRU, and finally fed to
    the actor/critic heads.

    :param num_inputs: forwarded to ``MarlGridImgModule``.
    :param num_direct_features: size of the non-image feature vector.
    :param use_norm: apply ``LayerNorm`` to the image features.
    :param use_conv_state_encoder: accepted but not used by this class.
    """

    def __init__(self, num_inputs, num_direct_features, recurrent=False, hidden_size=32, recurrent_type = 'gru', use_spectral_norm = False, use_conv_state_encoder = False, use_norm = False):
        super(MGMLPBase, self).__init__(recurrent, hidden_size, hidden_size, recurrent_type)

        self.num_direct_features = num_direct_features

        self.img_encoder = MarlGridImgModule(num_inputs, hidden_size)
        self.use_norm = use_norm
        if self.use_norm:
            self.img_norm = nn.LayerNorm(hidden_size)
        self.state_encoder_networks = StateEncoder(hidden_size + num_direct_features, hidden_size)

        # NOTE(review): without recurrence the heads receive the raw
        # `num_inputs` as their input size -- confirm that case is supported.
        head_inputs = hidden_size if recurrent else num_inputs
        self.comm_rl_head_networks = RLHeadNetworks(head_inputs, hidden_size, use_spectral_norm)
        self.train()

    def encode_obs(self, inputs, keep_grad = False):
        """Encode the image and append the direct features.

        :param inputs: pair ``(image, direct_features)``; the image is
            permuted with ``(0, 3, 1, 2)`` before encoding.
        :param keep_grad: when false the image encoding runs under
            ``torch.no_grad()``.
        :return: image features concatenated with ``inputs[1]`` on the last dim.
        """
        def image_features():
            feats = self.img_encoder(inputs[0].permute(0, 3, 1, 2))
            if self.use_norm:
                return self.img_norm(feats)
            return feats

        if keep_grad:
            img = image_features()
        else:
            with torch.no_grad():
                img = image_features()
        # Concat with direct features
        return torch.cat((img, inputs[1]), dim = -1)

    def forward(self, inputs, rnn_hxs, masks):
        """Return ``(output_critic, hidden_actor, rnn_hxs)`` for ``inputs``."""
        features = self.state_encoder_networks(self.encode_obs(inputs, keep_grad = True))
        if self.is_recurrent:
            features, rnn_hxs = self._forward_gru(features, rnn_hxs, masks)

        output_critic, hidden_actor = self.comm_rl_head_networks(features)
        return output_critic, hidden_actor, rnn_hxs


class CommMLPBase(NNBase):
    """Base network that consumes an observation plus an incoming message.

    ``message_space`` is read as ``(num_comm_inputs, comm_embed_size,
    num_comm_outputs)``. The observation is encoded by ``StateEncoder``, the
    message optionally by a Linear+ReLU decoder; both are concatenated,
    optionally passed through the GRU, and fed to the actor/critic heads and
    the message head.

    :param use_message_decoder: embed the incoming message before
        concatenation instead of using it raw.
    :param use_dial: when false, the first returned message is detached.
    """

    def __init__(self, num_inputs, message_space, recurrent=False, hidden_size=32, use_message_decoder = True, recurrent_type = 'gru', use_spectral_norm = False, use_dial = False):
        # The GRU sees the encoded state concatenated with the (decoded) message.
        if use_message_decoder:
            rnn_input_size = hidden_size + message_space[1]
        else:
            rnn_input_size = hidden_size + message_space[0]
        super(CommMLPBase, self).__init__(recurrent, rnn_input_size, hidden_size, recurrent_type)

        self.num_comm_inputs = message_space[0]
        self.comm_embed_size = message_space[1]
        self.num_comm_outputs = message_space[2]
        self._use_message_decoder = use_message_decoder
        self.use_dial = use_dial

        def init_(module):
            return init(
                module, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2)
            )

        self.state_encoder_networks = StateEncoder(num_inputs, hidden_size)

        # Input size of the heads.
        # NOTE(review): the non-recurrent sizes start from the raw
        # `num_inputs`, not the encoder width -- confirm that case is supported.
        if recurrent:
            num_inputs = hidden_size
        elif use_message_decoder:
            num_inputs += self.comm_embed_size
        else:
            num_inputs += self.num_comm_inputs

        # Message decoder - to be concatenated with observation
        if use_message_decoder:
            self.message_decoder = nn.Sequential(
                init_(nn.Linear(self.num_comm_inputs, self.comm_embed_size)),
                nn.ReLU()
            )

        self.comm_rl_head_networks = RLHeadNetworks(num_inputs, hidden_size, use_spectral_norm)
        self.comm_comm_head_networks = CommHeadNetworks(num_inputs, hidden_size, self.num_comm_outputs, use_spectral_norm)

        self.GRUS = GRUS
        self.LSTMS = LSTMS

        self.train()

    def encode_obs(self, inputs):
        """Return the state encoder's output for ``inputs``."""
        return self.state_encoder_networks(inputs)

    def forward(self, inputs, message_inputs, rnn_hxs, masks, comm_partial_with_grad = False):
        """Compute critic output, actor features and outgoing messages.

        :param comm_partial_with_grad: when true, the actor/critic heads run
            with gradient tracking disabled.
        :return: ``(output_critic, hidden_actor, (message, g_message),
            rnn_hxs)``; ``message`` is detached unless ``use_dial`` is set.
        """
        obs_features = self.state_encoder_networks(inputs)

        msg_features = message_inputs
        if self._use_message_decoder:
            msg_features = self.message_decoder(msg_features)

        x = torch.cat((obs_features, msg_features), 1)

        if self.is_recurrent:
            x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

        with torch.set_grad_enabled(not comm_partial_with_grad):
            output_critic, hidden_actor = self.comm_rl_head_networks(x)

        # The message head sees a detached copy of x so that learning the
        # message representation does not affect the RL features.
        comm_actor_m = self.comm_comm_head_networks(x.clone().detach())
        g_comm_actor_m = self.comm_comm_head_networks(x.clone().detach())

        # Without DIAL the outgoing message itself is detached as well, fully
        # separating message representation learning from RL.
        if not self.use_dial:
            comm_actor_m = comm_actor_m.detach()
        return output_critic, hidden_actor, (comm_actor_m, g_comm_actor_m), rnn_hxs


class CommAltMLPBase(NNBase):
    """Actor-critic base whose message head reads observation features directly.

    Unlike the RL heads, the message head does not see the incoming messages
    or the recurrent state: it is fed a detached copy of either the state
    encoder output or, when ``use_conv_state_encoder`` is set, the output of
    the pretrained image encoder.

    ``message_space`` is indexed as ``(num_comm_inputs, comm_embed_size,
    num_comm_outputs)``.
    """
    def __init__(self, num_inputs, message_space, recurrent=False, hidden_size=32, use_message_decoder = True, recurrent_type = 'gru', use_spectral_norm = False, use_dial = False, use_conv_state_encoder = False, aligner_type = None):
        super(CommAltMLPBase, self).__init__(recurrent, hidden_size + message_space[1] if use_message_decoder else hidden_size + message_space[0], hidden_size, recurrent_type)
        self.num_comm_inputs = message_space[0]
        self.comm_embed_size = message_space[1]
        self.num_comm_outputs = message_space[2]
        self._use_message_decoder = use_message_decoder
        self.use_dial = use_dial

        def init_(m):
            # Orthogonal weights (gain sqrt(2)) and zero bias.
            return init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        # Optional pretrained image encoder (SimCLR ResNet). When present, its
        # projection size replaces num_inputs for the state encoder and is
        # also what the message head consumes.
        self.img_encoder = None
        if use_conv_state_encoder:
            self.img_encoder, num_inputs = load_simclr_encoder()
            comm_head_inputs = num_inputs
        else:
            comm_head_inputs = hidden_size

        self.state_encoder_networks = StateEncoder(num_inputs, hidden_size)

        comm_head_cls = CommSplitHeadNetworks if aligner_type == 'mi_split' else CommHeadNetworks
        self.comm_comm_head_networks = comm_head_cls(comm_head_inputs, hidden_size, self.num_comm_outputs, use_spectral_norm)

        # RL heads read the GRU output when recurrent, otherwise the encoded
        # observation concatenated with the (decoded) message.
        rl_head_inputs = hidden_size
        if not recurrent:
            if use_message_decoder:
                rl_head_inputs += self.comm_embed_size
            else:
                rl_head_inputs += self.num_comm_inputs

        # Message decoder - its output is concatenated with the observation.
        # Single-layer variant ("Use for PP" in the original notes); a deeper
        # three-layer variant through hidden_size was noted as "Use for TJ".
        if use_message_decoder:
            self.message_decoder = nn.Sequential(
                init_(nn.Linear(self.num_comm_inputs, self.comm_embed_size)),
                nn.ReLU()
            )

        self.comm_rl_head_networks = RLHeadNetworks(rl_head_inputs, hidden_size, use_spectral_norm)

        self.GRUS = GRUS
        self.LSTMS = LSTMS

        # NOTE(review): train() also switches img_encoder back to training
        # mode although load_simclr_encoder() returns it in eval mode -
        # confirm this is intended for the frozen encoder.
        self.train()

    def encode_obs(self, inputs):
        """Return state-encoder features for ``inputs``.

        When an image encoder is configured it is applied first, with
        gradient tracking disabled.
        """
        if self.img_encoder is not None:
            with torch.no_grad():
                inputs = self.img_encoder(inputs)
        return self.state_encoder_networks(inputs)

    def forward(self, inputs, message_inputs, rnn_hxs, masks, comm_partial_with_grad = False):
        """Evaluate RL heads and message head for one batch.

        Returns:
            ``(output_critic, hidden_actor, (comm_actor_m, g_comm_actor_m), rnn_hxs)``.
            ``comm_actor_m`` is detached unless ``use_dial`` is set. When
            ``comm_partial_with_grad`` is True the RL heads run without
            gradient tracking.
        """
        x = inputs
        if self.img_encoder is not None:
            # Frozen image encoder; the message head reads its output directly.
            with torch.no_grad():
                x = self.img_encoder(x)
            comm_actor_m = self.comm_comm_head_networks(x.clone().detach())
            g_comm_actor_m = self.comm_comm_head_networks(x.clone().detach())
            x = self.state_encoder_networks(x)
        else:
            # The message head reads a detached copy of the encoded state.
            x = self.state_encoder_networks(x)
            comm_actor_m = self.comm_comm_head_networks(x.clone().detach())
            g_comm_actor_m = self.comm_comm_head_networks(x.clone().detach())

        m = message_inputs
        if self._use_message_decoder:
            m = self.message_decoder(message_inputs)
        x = torch.cat((x, m), 1)
        if self.is_recurrent:
            x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

        with torch.set_grad_enabled(not comm_partial_with_grad):
            output_critic, hidden_actor = self.comm_rl_head_networks(x)

        if not self.use_dial:
            comm_actor_m = comm_actor_m.detach()
        return output_critic, hidden_actor, (comm_actor_m, g_comm_actor_m), rnn_hxs

# Architecture for MARLGrid inputs, it has direct features that don't go through CNNs
class CommMGBase(NNBase):
    """Communication base that combines image features with direct features.

    ``inputs`` to ``encode_obs``/``forward`` is indexed as a pair: element 0
    goes through ``MarlGridImgModule`` after a ``permute(0, 3, 1, 2)``, and
    element 1 is concatenated to the resulting features on the last
    dimension. ``message_space`` is indexed as ``(num_comm_inputs,
    comm_embed_size, num_comm_outputs)``.
    """
    def __init__(self, num_inputs, num_direct_features, message_space, recurrent=False, hidden_size=32, use_message_decoder = True, recurrent_type = 'gru', use_spectral_norm = False, use_dial = False, use_conv_state_encoder = False, aligner_type = None, use_norm = False):
        super(CommMGBase, self).__init__(recurrent, hidden_size + message_space[1] if use_message_decoder else hidden_size + message_space[0], hidden_size, recurrent_type)
        self.num_comm_inputs, self.comm_embed_size, self.num_comm_outputs = message_space[0], message_space[1], message_space[2]
        self._use_message_decoder = use_message_decoder
        self.num_direct_features = num_direct_features
        self.use_dial = use_dial

        def init_(m):
            # Orthogonal weights (gain sqrt(2)) and zero bias.
            return init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        # Width of the image features joined with the direct features.
        joint_size = hidden_size + num_direct_features

        self.img_encoder = MarlGridImgModule(num_inputs, hidden_size)
        self.state_encoder_networks = StateEncoder(joint_size, hidden_size)
        self.use_norm = use_norm
        if self.use_norm:
            # Optional layer norm applied to the image features only.
            self.img_norm = nn.LayerNorm(hidden_size)

        if aligner_type == 'mi_split':
            comm_head_cls = CommSplitHeadNetworks
        else:
            comm_head_cls = CommHeadNetworks
        self.comm_comm_head_networks = comm_head_cls(joint_size, hidden_size, self.num_comm_outputs, use_spectral_norm)

        # RL heads: GRU output when recurrent, otherwise encoded state plus
        # the (decoded) message.
        num_inputs = hidden_size
        if recurrent == False:
            num_inputs += self.comm_embed_size if use_message_decoder else self.num_comm_inputs

        # Message decoder - its output is concatenated with the observation.
        if use_message_decoder:
            decoder_layers = [
                init_(nn.Linear(self.num_comm_inputs, hidden_size)),
                nn.ReLU(),
                init_(nn.Linear(hidden_size, hidden_size)),
                nn.ReLU(),
                init_(nn.Linear(hidden_size, self.comm_embed_size)),
                nn.ReLU(),
            ]
            self.message_decoder = nn.Sequential(*decoder_layers)

        self.comm_rl_head_networks = RLHeadNetworks(num_inputs, hidden_size, use_spectral_norm)

        self.GRUS = GRUS
        self.LSTMS = LSTMS

        self.train()

    def encode_obs(self, inputs, keep_grad = False):
        """Encode the image part of ``inputs`` and append the direct features.

        With ``keep_grad=False`` the image branch runs under
        ``torch.no_grad()``; with ``keep_grad=True`` the ambient gradient
        mode is left as it is.
        """
        def _image_features():
            feats = self.img_encoder(inputs[0].permute(0, 3, 1, 2))
            return self.img_norm(feats) if self.use_norm else feats

        if keep_grad:
            x = _image_features()
        else:
            with torch.no_grad():
                x = _image_features()
        # Concat with direct features
        return torch.cat((x, inputs[1]), dim = -1)

    def forward(self, inputs, message_inputs, rnn_hxs, masks, comm_partial_with_grad = False):
        """Evaluate RL heads and message head for one batch.

        Returns:
            ``(output_critic, hidden_actor, (comm_actor_m, g_comm_actor_m), rnn_hxs)``.
            ``comm_actor_m`` is detached unless ``use_dial`` is set.
        """
        joint = self.encode_obs(inputs, keep_grad = True)

        # The message head reads a detached copy of the joint features and is
        # evaluated twice, once per returned output.
        comm_actor_m = self.comm_comm_head_networks(joint.clone().detach())
        g_comm_actor_m = self.comm_comm_head_networks(joint.clone().detach())

        x = self.state_encoder_networks(joint)
        m = self.message_decoder(message_inputs) if self._use_message_decoder else message_inputs
        x = torch.cat((x, m), 1)
        if self.is_recurrent:
            x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

        with torch.set_grad_enabled(not comm_partial_with_grad):
            output_critic, hidden_actor = self.comm_rl_head_networks(x)

        if not self.use_dial:
            comm_actor_m = comm_actor_m.detach()
        return output_critic, hidden_actor, (comm_actor_m, g_comm_actor_m), rnn_hxs

# Communication uses separate RNN
class CommMLPBase_CSRNN(NNBase):
    """Communication base where the message head uses its own recurrent path.

    The RL heads read the output of ``_forward_gru``; the message head reads
    the output of ``_forward_comm_gru``, which is fed a detached copy of the
    same pre-GRU features. ``message_space`` is indexed as
    ``(num_comm_inputs, comm_embed_size, num_comm_outputs)``.
    """
    def __init__(self, num_inputs, message_space, recurrent=False, hidden_size=32, use_message_decoder = True, recurrent_type = 'gru', use_spectral_norm = False):
        super(CommMLPBase_CSRNN, self).__init__(recurrent, hidden_size + message_space[1] if use_message_decoder else hidden_size + message_space[0], hidden_size, recurrent_type)
        self.num_comm_inputs = message_space[0]
        self.comm_embed_size = message_space[1]
        self.num_comm_outputs = message_space[2]
        self._use_message_decoder = use_message_decoder

        def init_(m):
            # Orthogonal weights (gain sqrt(2)) and zero bias.
            return init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        # state encoder
        self.state_encoder_networks = StateEncoder(num_inputs, hidden_size)

        # Head input width. forward() always feeds the heads features derived
        # from the state encoder output (hidden_size), so the non-recurrent
        # width is hidden_size plus the message width, not the raw
        # observation width plus the message width.
        head_inputs = hidden_size
        if not recurrent:
            if use_message_decoder:
                head_inputs += self.comm_embed_size
            else:
                head_inputs += self.num_comm_inputs

        # Message decoder - its output is concatenated with the observation.
        if use_message_decoder:
            self.message_decoder = nn.Sequential(
                init_(nn.Linear(self.num_comm_inputs, self.comm_embed_size)),
                nn.ReLU()
            )

        self.comm_rl_head_networks = RLHeadNetworks(head_inputs, hidden_size, use_spectral_norm)
        self.comm_comm_head_networks = CommHeadNetworks(head_inputs, hidden_size, self.num_comm_outputs)

        self.GRUS = GRUS
        self.LSTMS = LSTMS

        self.train()

    def forward(self, inputs, message_inputs, rnn_hxs, comm_hxs, masks, comm_partial_with_grad = False):
        """Evaluate RL heads and message head for one batch.

        Args:
            comm_hxs: hidden state of the communication recurrent path;
                returned unchanged when the base is not recurrent.
            comm_partial_with_grad: when True the RL heads run without
                gradient tracking.

        Returns:
            ``(output_critic, hidden_actor, (comm_actor_m, g_comm_actor_m),
            rnn_hxs, c_rnn_hxs)``; ``comm_actor_m`` is always detached.
        """
        x = self.state_encoder_networks(inputs)

        m = message_inputs
        if self._use_message_decoder:
            m = self.message_decoder(message_inputs)
        x = torch.cat((x, m), 1)

        if self.is_recurrent:
            # The comm GRU gets a detached copy of the pre-GRU features so its
            # gradients do not reach the state encoder or message decoder.
            # Call order matches the original tuple expression.
            gru_out, new_rnn_hxs = self._forward_gru(x, rnn_hxs, masks)
            c_x, c_rnn_hxs = self._forward_comm_gru(x.clone().detach(), comm_hxs, masks)
            x, rnn_hxs = gru_out, new_rnn_hxs
        else:
            # No recurrent paths: the message head reads the detached
            # features directly and the comm hidden state passes through.
            # Previously these names were unbound on this path.
            c_x, c_rnn_hxs = x.clone().detach(), comm_hxs

        with torch.set_grad_enabled(not comm_partial_with_grad):
            output_critic, hidden_actor = self.comm_rl_head_networks(x)

        comm_actor_m = self.comm_comm_head_networks(c_x.clone().detach())
        # Not detached here: gradients can still go through the comm GRU.
        g_comm_actor_m = self.comm_comm_head_networks(c_x.clone())
        return output_critic, hidden_actor, (comm_actor_m.detach(), g_comm_actor_m), rnn_hxs, c_rnn_hxs

class CommGatedMLPBase(CommMLPBase):
    """``CommMLPBase`` variant whose message head is ``CommGatedeadNetworks``.

    NOTE(review): ``use_spectral_norm`` is accepted but neither forwarded to
    the parent constructor nor used here - confirm whether that is intended.
    """
    def __init__(self, num_inputs, message_space, recurrent=False, hidden_size=32, use_message_decoder = True, recurrent_type = 'gru', use_spectral_norm = False):
        super(CommGatedMLPBase, self).__init__(num_inputs, message_space, recurrent, hidden_size, use_message_decoder, recurrent_type)

        # Input width of the gated message head.
        if recurrent:
            head_inputs = hidden_size
        elif use_message_decoder:
            head_inputs = num_inputs + self.comm_embed_size
        else:
            head_inputs = num_inputs + self.num_comm_inputs

        # Replaces whatever message head the parent constructor installed.
        self.comm_comm_head_networks = CommGatedeadNetworks(head_inputs, hidden_size, self.num_comm_outputs)

    def forward(self, inputs, message_inputs, rnn_hxs, masks, comm_partial_with_grad = False):
        """Evaluate RL heads and the gated message head for one batch.

        Returns:
            ``(output_critic, hidden_actor, comm_actor_m, rnn_hxs)``. The
            message head output is returned as-is (no detach). When
            ``comm_partial_with_grad`` is True the RL heads run without
            gradient tracking.
        """
        encoded = self.state_encoder_networks(inputs)
        if self._use_message_decoder:
            msg = self.message_decoder(message_inputs)
        else:
            msg = message_inputs

        x = torch.cat((encoded, msg), 1)

        if self.is_recurrent:
            x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

        with torch.set_grad_enabled(not comm_partial_with_grad):
            output_critic, hidden_actor = self.comm_rl_head_networks(x)
        comm_actor_m = self.comm_comm_head_networks(x)

        # A threaded evaluation of the two heads was tried and noted as not
        # measurably faster (7:38 vs 7:40 for 1000 updates).
        return output_critic, hidden_actor, comm_actor_m, rnn_hxs

        # # Non Jit
        # if(comm_partial_with_grad):
        #     with torch.no_grad():
        #         hidden_critic = self.critic_linear(self.critic(x))
        #         hidden_actor = self.actor(x)
        #     comm_actor_m = self.comm_actor(x)
        #     return hidden_critic, hidden_actor, comm_actor_m, rnn_hxs
        # else:
        #     hidden_critic = self.critic(x)
        #     hidden_actor = self.actor(x)
        #     comm_actor_m = self.comm_actor(x)

        #     return self.critic_linear(hidden_critic), hidden_actor, comm_actor_m, rnn_hxs


# DND
class CommDNDMLPBase(NNBase):
    """Communication base that also returns a key built from features and hidden state.

    ``forward`` does not apply the state encoder itself; encoded observations
    are available through ``get_compressed_obs``. ``message_space`` is indexed
    as ``(num_comm_inputs, comm_embed_size, num_comm_outputs)``.
    """
    def __init__(self, num_inputs, message_space, recurrent=False, hidden_size=32, use_message_decoder = True, recurrent_type = 'gru', use_spectral_norm = False):
        super(CommDNDMLPBase, self).__init__(recurrent, hidden_size + message_space[1] if use_message_decoder else hidden_size + message_space[0], hidden_size, recurrent_type)
        self.num_comm_inputs, self.comm_embed_size, self.num_comm_outputs = message_space[0], message_space[1], message_space[2]
        self._use_message_decoder = use_message_decoder

        # state encoder
        self.state_encoder_networks = StateEncoder(num_inputs, hidden_size)

        # Input width of both heads.
        # NOTE(review): in the non-recurrent case this uses the raw
        # num_inputs, while forward() never applies the state encoder -
        # confirm the width callers actually pass in.
        if recurrent:
            head_inputs = hidden_size
        elif use_message_decoder:
            head_inputs = num_inputs + self.comm_embed_size
        else:
            head_inputs = num_inputs + self.num_comm_inputs

        def init_(m):
            # Orthogonal weights (gain sqrt(2)) and zero bias.
            return init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        # Message decoder - its output is concatenated with the observation.
        if use_message_decoder:
            decoder_layers = [
                init_(nn.Linear(self.num_comm_inputs, self.comm_embed_size)),
                nn.ReLU(),
            ]
            self.message_decoder = nn.Sequential(*decoder_layers)

        self.comm_rl_head_networks = RLHeadNetworks(head_inputs, hidden_size, use_spectral_norm)
        self.comm_comm_head_networks = CommHeadNetworks(head_inputs, hidden_size, self.num_comm_outputs)

        self.train()

    def get_compressed_obs(self, inputs):
        """Return the state encoder output for ``inputs``."""
        return self.state_encoder_networks(inputs)

    def forward(self, inputs, message_inputs, rnn_hxs, masks, comm_partial_with_grad = False):
        """Evaluate both heads and build the key.

        Returns:
            ``(output_critic, hidden_actor, comm_actor_m, rnn_hxs, key)``.
            ``key`` is the features concatenated with ``rnn_hxs`` along dim 1
            when their first dimensions match, otherwise ``None``.
        """
        if self._use_message_decoder:
            msg = self.message_decoder(message_inputs)
        else:
            msg = message_inputs
        x = torch.cat((inputs, msg), 1)

        if self.is_recurrent:
            x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

        # RL heads optionally run without gradient tracking.
        with torch.set_grad_enabled(not comm_partial_with_grad):
            output_critic, hidden_actor = self.comm_rl_head_networks(x)
        comm_actor_m = self.comm_comm_head_networks(x)

        # A key can only be formed when features and hidden state line up
        # row by row.
        key = torch.cat((x, rnn_hxs), dim = 1) if x.size(0) == rnn_hxs.size(0) else None

        return output_critic, hidden_actor, comm_actor_m, rnn_hxs, key

        # # Non Jit
        # if(comm_partial_with_grad):
        #     with torch.no_grad():
        #         hidden_critic = self.critic_linear(self.critic(x))
        #         hidden_actor = self.actor(x)
        #     comm_actor_m = self.comm_actor(x)
        #     return hidden_critic, hidden_actor, comm_actor_m, rnn_hxs
        # else:
        #     hidden_critic = self.critic(x)
        #     hidden_actor = self.actor(x)
        #     comm_actor_m = self.comm_actor(x)

        #     return self.critic_linear(hidden_critic), hidden_actor, comm_actor_m, rnn_hxs

# Shared DND with separate key generators
class CommSDNDMLPBase(NNBase):
    """Communication base with separate read-key and write-key generators.

    ``forward`` does not apply the state encoder itself; encoded observations
    are available through ``get_compressed_obs``. ``message_space`` is indexed
    as ``(num_comm_inputs, comm_embed_size, num_comm_outputs)``.
    """
    def __init__(self, num_inputs, message_space, recurrent=False, hidden_size=32, key_size = 6, use_message_decoder = True, recurrent_type = 'gru', use_spectral_norm = False):
        super(CommSDNDMLPBase, self).__init__(recurrent, hidden_size + message_space[1] if use_message_decoder else hidden_size + message_space[0], hidden_size, recurrent_type)
        self.num_comm_inputs, self.comm_embed_size, self.num_comm_outputs = message_space[0], message_space[1], message_space[2]
        self._use_message_decoder = use_message_decoder

        # state encoder
        self.state_encoder_networks = StateEncoder(num_inputs, hidden_size)

        # Input width of both heads.
        # NOTE(review): in the non-recurrent case this uses the raw
        # num_inputs, while forward() never applies the state encoder -
        # confirm the width callers actually pass in.
        if recurrent:
            head_inputs = hidden_size
        elif use_message_decoder:
            head_inputs = num_inputs + self.comm_embed_size
        else:
            head_inputs = num_inputs + self.num_comm_inputs

        def init_(m):
            # Orthogonal weights (gain sqrt(2)) and zero bias.
            return init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        # Message decoder - its output is concatenated with the observation.
        if use_message_decoder:
            decoder_layers = [
                init_(nn.Linear(self.num_comm_inputs, self.comm_embed_size)),
                nn.ReLU(),
            ]
            self.message_decoder = nn.Sequential(*decoder_layers)

        self.comm_rl_head_networks = RLHeadNetworks(head_inputs, hidden_size, use_spectral_norm)
        self.comm_comm_head_networks = CommHeadNetworks(head_inputs, hidden_size, self.num_comm_outputs)

        # Shared DND key generators; both take a context of width
        # 2 * hidden_size.
        context_size = hidden_size * 2
        self.mem_read_key_generator = DNDKeyGenerator(context_size, key_size)
        self.mem_write_key_generator = DNDKeyGenerator(context_size, key_size)

        self.train()

    def get_compressed_obs(self, inputs):
        """Return the state encoder output for ``inputs``."""
        return self.state_encoder_networks(inputs)

    def get_read_keys(self, inputs, hiddens):
        """Generate read keys from ``inputs`` and ``hiddens`` joined along dim 1."""
        return self.mem_read_key_generator(torch.cat((inputs, hiddens), dim = 1))

    def get_write_keys(self, inputs, hiddens):
        """Generate write keys from ``inputs`` and ``hiddens`` joined along dim 1."""
        return self.mem_write_key_generator(torch.cat((inputs, hiddens), dim = 1))

    def forward(self, inputs, message_inputs, rnn_hxs, masks, comm_partial_with_grad = False):
        """Evaluate both heads for one batch.

        Returns:
            ``(output_critic, hidden_actor, comm_actor_m, rnn_hxs)``. When
            ``comm_partial_with_grad`` is True the RL heads run without
            gradient tracking.
        """
        if self._use_message_decoder:
            msg = self.message_decoder(message_inputs)
        else:
            msg = message_inputs
        x = torch.cat((inputs, msg), 1)
        if self.is_recurrent:
            x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

        with torch.set_grad_enabled(not comm_partial_with_grad):
            output_critic, hidden_actor = self.comm_rl_head_networks(x)
        comm_actor_m = self.comm_comm_head_networks(x)

        return output_critic, hidden_actor, comm_actor_m, rnn_hxs


# Linear Memory
class CommLinearMemMLPBase(NNBase):
    """Communication base with a ``LinMemReader`` that weights memory entries.

    ``forward`` does not apply the state encoder itself; encoded observations
    are available through ``get_compressed_obs``. ``message_space`` is indexed
    as ``(num_comm_inputs, comm_embed_size, num_comm_outputs)``.
    """
    def __init__(self, num_inputs, message_space, mem_buffer_size, recurrent=False, hidden_size=32, use_message_decoder = True, recurrent_type = 'gru', use_spectral_norm = False):
        super(CommLinearMemMLPBase, self).__init__(recurrent, hidden_size + message_space[1] if use_message_decoder else hidden_size + message_space[0], hidden_size, recurrent_type)
        self.num_comm_inputs, self.comm_embed_size, self.num_comm_outputs = message_space[0], message_space[1], message_space[2]
        self._use_message_decoder = use_message_decoder

        # state encoder
        self.state_encoder_networks = StateEncoder(num_inputs, hidden_size)

        # linear memory reader
        self.mem_reader_networks = LinMemReader(hidden_size * 2, hidden_size, mem_buffer_size, self.num_comm_outputs)

        # Input width of both heads.
        # NOTE(review): in the non-recurrent case this uses the raw
        # num_inputs, while forward() never applies the state encoder -
        # confirm the width callers actually pass in.
        if recurrent:
            head_inputs = hidden_size
        elif use_message_decoder:
            head_inputs = num_inputs + self.comm_embed_size
        else:
            head_inputs = num_inputs + self.num_comm_inputs

        def init_(m):
            # Orthogonal weights (gain sqrt(2)) and zero bias.
            return init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        # Message decoder - its output is concatenated with the observation.
        if use_message_decoder:
            decoder_layers = [
                init_(nn.Linear(self.num_comm_inputs, self.comm_embed_size)),
                nn.ReLU(),
            ]
            self.message_decoder = nn.Sequential(*decoder_layers)

        self.comm_rl_head_networks = RLHeadNetworks(head_inputs, hidden_size, use_spectral_norm)
        self.comm_comm_head_networks = CommHeadNetworks(head_inputs, hidden_size, self.num_comm_outputs)

        self.train()

    def get_compressed_obs(self, inputs):
        """Return the state encoder output for ``inputs``."""
        return self.state_encoder_networks(inputs)

    def read_messages_from_memory(self, encoded_inputs, memory_content):
        """Return the reader-weighted sum of ``memory_content`` over dim 1.

        The reader output is expanded along a new last dimension to
        ``memory_content.size(2)`` before the element-wise product.
        """
        reader_out = self.mem_reader_networks(encoded_inputs, memory_content)
        mem_mask = reader_out.unsqueeze(2).repeat(1, 1, memory_content.size(2))
        return torch.sum(mem_mask * memory_content, dim = 1)

    def forward(self, inputs, message_inputs, rnn_hxs, masks, comm_partial_with_grad = False):
        """Evaluate both heads for one batch.

        Returns:
            ``(output_critic, hidden_actor, comm_actor_m, rnn_hxs)``. When
            ``comm_partial_with_grad`` is True the RL heads run without
            gradient tracking.
        """
        if self._use_message_decoder:
            msg = self.message_decoder(message_inputs)
        else:
            msg = message_inputs
        x = torch.cat((inputs, msg), 1)

        if self.is_recurrent:
            x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

        with torch.set_grad_enabled(not comm_partial_with_grad):
            output_critic, hidden_actor = self.comm_rl_head_networks(x)
        comm_actor_m = self.comm_comm_head_networks(x)

        return output_critic, hidden_actor, comm_actor_m, rnn_hxs


def mg_weights_init(m):
    """Initialise Conv*/Linear* modules in place; intended for ``Module.apply``.

    Modules are selected by class name: names containing ``'Conv'`` or
    ``'Linear'`` get weights drawn uniformly from ``[-b, b]`` with
    ``b = sqrt(6 / (fan_in + fan_out))`` and a zero bias. Other modules are
    left untouched.

    Modules that match by name but carry no ``weight`` (for example a
    container class whose name happens to contain ``Conv``) are skipped, and
    a missing bias (``bias=False``) is tolerated instead of raising.
    """
    classname = m.__class__.__name__
    is_conv = classname.find('Conv') != -1
    is_linear = classname.find('Linear') != -1
    if not (is_conv or is_linear):
        return

    weight = getattr(m, 'weight', None)
    if weight is None:
        return

    weight_shape = list(weight.data.size())
    if is_conv:
        # Weight shape is indexed as (out, in, k...): fan_in spans dims 1..3,
        # fan_out spans the kernel dims times the output channels.
        fan_in = np.prod(weight_shape[1:4])
        fan_out = np.prod(weight_shape[2:4]) * weight_shape[0]
    else:
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]

    w_bound = np.sqrt(6. / (fan_in + fan_out))
    weight.data.uniform_(-w_bound, w_bound)

    bias = getattr(m, 'bias', None)
    if bias is not None:
        bias.data.fill_(0)

class MarlGridImgModule(nn.Module):
    """Convolutional image encoder producing a flat feature vector.

    Four stride-2 conv layers with 32 channels each are followed by adaptive
    average pooling to 3x3, giving ``32 * 3 * 3 = 288`` features, optionally
    projected by a final linear layer.

    Args:
        input_size: observation shape; only ``input_size[2]`` is read and is
            used as the number of input channels (presumably an (H, W, C)
            shape - TODO confirm against callers).
        last_fc_dim: output size of the final linear layer; when not
            positive no linear layer is created and the 288 pooled features
            are returned directly.
    """
    def __init__(self, input_size, last_fc_dim=0):
        super(MarlGridImgModule, self).__init__()
        self.conv1 = self.make_layer(input_size[2], 32)
        self.conv2 = self.make_layer(32, 32)
        self.conv3 = self.make_layer(32, 32)
        self.conv4 = self.make_layer(32, 32)
        self.avgpool = nn.AdaptiveAvgPool2d([3, 3])
        if last_fc_dim > 0:
            self.fc = nn.Linear(288, last_fc_dim)
        else:
            self.fc = None
        self.apply(mg_weights_init)

    def make_layer(self, in_ch, out_ch, use_norm=True):
        """Build one block: 3x3 stride-2 conv, ELU, optional InstanceNorm2d."""
        layer = [nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)]
        # The original passed `True` positionally, which nn.ELU reads as
        # `alpha` (numerically 1), not `inplace`. Spell alpha out so the
        # activation is unchanged and its configuration is explicit.
        layer.append(nn.ELU(alpha=1.0))
        if use_norm:
            layer.append(nn.InstanceNorm2d(out_ch))
        return nn.Sequential(*layer)

    def forward(self, inputs):
        """Encode ``inputs`` (fed directly to Conv2d, i.e. channels-first).

        Returns the flattened pooled features with 288 columns, or their
        linear projection when ``fc`` is configured.
        """
        x = inputs
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = conv(x)
        x = self.avgpool(x)
        x = x.view(-1, 32 * 3 * 3)  # feature dim
        if self.fc is not None:
            return self.fc(x)
        return x
